## Validate `outcome` and return its canonical full name ---------------
validate_outcome <- function(outcome){
  # Partial names are completed ("bin" -> "binary"); the untouched default
  # vector c("binary", "continuous", "count") resolves to its first entry,
  # "binary"; any other value makes match.arg() raise an error.
  allowed <- c("binary", "continuous", "count")
  match.arg(arg = outcome, choices = allowed)
}

## Posterior hyper-parameters after n observations with mean e_prime ---
# binary:     list(a = a + m, b = b + n - m), with m = floor(n * e_prime)
# count:      list(a = a + m, b = b + n)      (shape / rate update)
# continuous: Normal prior with mean a and sd b, known sampling sd sigma;
#             returns the posterior mean `mu` and variance `var`.
# Any other `outcome` string falls through the switch and yields NULL.
posterior_params <- function(n, e_prime, a, b, outcome, sigma = 1){
  switch(outcome,
         "continuous" = {
           # posterior precision = prior precision + data precision
           post_var <- 1 / (1 / b^2 + n / sigma^2)
           list(mu  = post_var * (a / b^2 + n * e_prime / sigma^2),
                var = post_var)
         },
         "binary" = {
           m <- floor(n * e_prime)
           list(a = a + m, b = b + n - m)
         },
         "count" = {
           m <- floor(n * e_prime)
           list(a = a + m, b = b + n)
         })
}

# -------------------------------------------------------------------------
# Posterior probability that θ > θ₀ + θ* ----------------------------------
# -------------------------------------------------------------------------
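# Conjugate prior by outcome (updated with n observations of mean e + θ₀):
# outcome = "binary"     -> Beta(a, b)
# outcome = "continuous" -> Normal(mean = a, sd = b), known sampling sd sigma
# outcome = "count"      -> Gamma(a, rate = b)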
# Returns P(theta > theta0 + theta_star | data) under the conjugate
# posterior obtained from posterior_params() with observed mean e + theta0.
post_prob_theta <- function(n, theta_star, e, theta0 = 0, a, b, sigma = 1,
                            outcome = c("binary", "continuous", "count")){
  outcome   <- validate_outcome(outcome)
  threshold <- theta_star + theta0
  post      <- posterior_params(n, e + theta0, a, b, outcome, sigma)

  if (outcome == "binary") {
    return(pbeta(threshold, shape1 = post$a, shape2 = post$b,
                 lower.tail = FALSE))
  }
  if (outcome == "count") {
    return(pgamma(threshold, shape = post$a, rate = post$b,
                  lower.tail = FALSE))
  }
  # continuous: standardise against the Normal posterior
  pnorm((threshold - post$mu) / sqrt(post$var), lower.tail = FALSE)
}

# -------------------------------------------------------------------------
# Posterior probability of H₁ (BESS one-arm) ------------------------------
# -------------------------------------------------------------------------
# Posterior probability of H1 (theta > theta0 + theta_star) for the one-arm
# design. The posterior tail from post_prob_theta() is reweighted by the
# prior tail mass C1 (and C0 = 1 - C1) and the prior probability q of H1.
post_prob_H1 <- function(n, theta_star, e, a, b, 
                         theta0 = 0, sigma = 1, q  = 0.5,
                         outcome = c("binary", "continuous", "count")){
  outcome <- validate_outcome(outcome)
  cutoff  <- theta_star + theta0

  post_tail <- post_prob_theta(n, theta_star, e, theta0 = theta0,
                               a = a, b = b, sigma = sigma, outcome = outcome)

  # Prior mass above the cutoff
  prior_tail <- if (outcome == "binary") {
    pbeta(cutoff, a, b, lower.tail = FALSE)
  } else if (outcome == "continuous") {
    pnorm(cutoff, mean = a, sd = b, lower.tail = FALSE)
  } else {
    pgamma(cutoff, shape = a, rate = b, lower.tail = FALSE)
  }
  prior_body <- 1 - prior_tail

  if (isTRUE(all.equal(q, 0.5))) {
    # q = 0.5: the q factors cancel
    return((prior_body * post_tail) /
             (prior_tail + (prior_body - prior_tail) * post_tail))
  }
  weight1 <- q / prior_tail
  weight0 <- (1 - q) / prior_body
  (weight1 * post_tail) / (weight0 + (weight1 - weight0) * post_tail)
}

# ------------------------------------------------------------------------
# BESS Algorithm 3 — minimum sample size (n_min) for one-arm trials ------
# ------------------------------------------------------------------------
# Minimum sample size n_min for a one-arm BESS design (Algorithm 3).
#
# Arguments:
#   e          evidence; e + theta0 is used as the observed mean/proportion
#   theta_star clinically minimum effect size (threshold is theta_star + theta0)
#   a, b       prior hyper-parameters
#   n_max      largest candidate n scanned
#   theta0     reference response (default 0)
#   outcome    "binary", "continuous" or "count"
#   sigma      sampling sd; required (non-NA) for a continuous outcome
#
# binary/count: a forward scan over n = 1..n_max finds the first n where the
# ratio of successive posterior normalising constants exceeds the likelihood
# kernel at theta; a backward walk from there returns n + 1 for the first n
# at which the posterior CDF at theta does not decrease from n to n + 1.
#
# Errors if no n in 1..n_max passes the forward criterion, or if sigma is
# missing for a continuous outcome.
find_min_n_BESS1 <- function(e, theta_star, a, b, n_max, theta0 = 0,
                             outcome = c("binary", "continuous", "count"),
                             sigma   = NA_real_) {

  outcome <- validate_outcome(outcome)
  theta   <- theta_star + theta0   # θ* + θ₀
  e_prime <- e + theta0            # e  + θ₀
  n_grid  <- seq_len(n_max)

  # ---- continuous: closed-form criterion -------------------------------
  if (outcome == "continuous") {
    if (is.na(sigma))
      stop("For a continuous outcome 'sigma' must be supplied (non-NA).")

    # NOTE(review): here b enters as a prior *variance* (a / b, 1 / b),
    # whereas posterior_params() and post_prob_H1() treat b as a prior sd
    # (b^2); theta0 is also unused. Left unchanged - confirm against the
    # BESS derivation.
    val <- e - (theta_star + (a / b) * (1 / b + n_grid / sigma^2)^(-1))
    n_min <- which(val >= 0)[1]

    if (is.na(n_min))
      stop("Increase `n_max`; target not attained.")
    return(n_min)
  }

  # ---- binary / count: forward criterion and posterior CDF -------------
  if (outcome == "binary") {
    # Beta(a + n e', b + n (1 - e')) posterior
    diff_vec <- exp(lbeta(a + (n_grid + 1) * e_prime,
                          b + (n_grid + 1) * (1 - e_prime)) -
                      lbeta(a + n_grid * e_prime,
                            b + n_grid * (1 - e_prime))) -
      theta^e_prime * (1 - theta)^(1 - e_prime)

    post_cdf <- function(n) {
      pbeta(theta, a + n * e_prime, b + n * (1 - e_prime))
    }
  } else {
    # Gamma(a + n e', rate = b + n) posterior
    diff_vec <- exp(lgamma(a + (n_grid + 1) * e_prime) -
                      lgamma(a + n_grid * e_prime) +
                      (a + n_grid * e_prime) * log(b + n_grid) -
                      (a + (n_grid + 1) * e_prime) * log(b + n_grid + 1)) -
      theta^e_prime * exp(-theta)

    # The rate must be b + n, matching diff_vec above and posterior_params().
    # The previous code reused the binary update b + n * (1 - e').
    post_cdf <- function(n) {
      pgamma(theta, shape = a + n * e_prime, rate = b + n)
    }
  }

  # First n in 1..n_max where the forward criterion is positive
  n_min_prime <- which(diff_vec > 0)[1]
  if (is.na(n_min_prime)) {
    stop("Increase `n_max`; target not attained.")
  }

  # Back-search: walk n = n_min_prime, ..., 0. Return n + 1 at the first n
  # whose posterior CDF at theta does not decrease from n to n + 1.
  # n = 0 always terminates the walk.
  for (n in rev(seq.int(0L, n_min_prime))) {
    if (n == 0L || post_cdf(n) - post_cdf(n + 1L) <= 0) {
      return(n + 1)
    }
  }
}

# ------------------------------------------------------------------------
# Unit-Testing Find min n ------------------------------------------------
# ------------------------------------------------------------------------

#theta_star <- 0.3
#find_min_n_BESS1(0.4, theta_star, 0.5, 0.5, 500, outcome = "binary")
#find_min_n_BESS1(0.35, theta_star, 0, 10, 500, outcome = "continuous", sigma = 1)
#find_min_n_BESS1(0.375, theta_star, 1, 2, 500, outcome = "count")
#find_min_n_BESS1(0.4, theta_star, 1, 2, 500, outcome = "survival")

# ------------------------------------------------------------------------
# Check whether the count-outcome parameters satisfy Proposition 3 -------
# ------------------------------------------------------------------------
# Parameters:
#   e: evidence; 
#   a, b: hyperparameters of the Gamma(a, rate = b) prior
check_param_BESS1 <- function(e, a, b){
  # Proposition 3 holds when the combined log-ratio is non-negative, i.e.
  # when exp(log-ratio) >= 1. The three parts are the Gamma-function term,
  # the a-weighted rate term and the e-weighted rate term. Vectorised
  # over e, a and b.
  log_b1 <- log(b + 1)
  log_b2 <- log(b + 2)

  gamma_part <- lgamma(a + 2 * e) + lgamma(a) - 2 * lgamma(a + e)
  a_part     <- a * (2 * log_b1 - log(b) - log_b2)
  e_part     <- e * (log_b1 - log_b2)

  exp(gamma_part + a_part + e_part) >= 1
}

# ------------------------------------------------------------------------
# BESS Algorithm 1: Find sample size for one-arm trials ------------------
# ------------------------------------------------------------------------
# Parameters:
#   theta_star: clinically minimum effect size;
#   c: confidence level; e: evidence; a, b: hyperparameters;
#   theta0 = 0: reference response; 
#   outcome: one of "binary", "continuous", "count";
#   q = 0.5: prior prob of H_1;
#   sig = 1: standard deviation of the continuous outcome;
#   n_min, n_max: minimum and maximum candidate sample sizes; 
# One-arm BESS sample size (Algorithm 1).
#
# Evaluates post_prob_H1() for every n in seq(n_min, n_max) and returns:
#   - "Can lower n_min." if the probability already reaches c at n_min;
#   - "Need to increase max sample size!" if it is still below c at n_max;
#   - otherwise list(n, n_seq, prob_vec), where n is the first n reaching c.
#     For binary/count outcomes the list also carries m_vec, the number of
#     successes/counts floor(n * (e + theta0)) that the posterior conditions on.
BESS_one_arm <- function(theta_star, c, e, a, b, theta0 = 0, 
                         outcome = c("binary", "continuous", "count"), 
                         q = 0.5, sig = 1, n_min = 5, n_max = 100){

  # Resolve outcome once. With the default length-3 vector, the later
  # `outcome == "binary" || ...` test errors in R >= 4.3.
  outcome <- validate_outcome(outcome)

  n_seq <- seq(n_min, n_max)

  s_vec <- vapply(n_seq, FUN = post_prob_H1, numeric(1),
                  theta_star = theta_star, e = e, a = a, b = b,
                  theta0 = theta0, sigma = sig, outcome = outcome, q = q)

  if (s_vec[1] >= c) {
    return("Can lower n_min.")
  }
  if (s_vec[length(s_vec)] < c) {
    return("Need to increase max sample size!")
  }

  idx <- which(s_vec >= c)[1]

  result <- list(
    n        = n_seq[idx],
    n_seq    = n_seq,
    prob_vec = s_vec
  )
  if (outcome == "binary" || outcome == "count") {
    # Same observed mean (e + theta0) as used by posterior_params()
    result$m_vec <- floor(n_seq * (e + theta0))
  }
  result
}

# ------------------------------------------------------------------------
# Unit-Testing One-Arm Cases ---------------------------------------------
# ------------------------------------------------------------------------

#theta_star <- 0.3
#c <- 0.8
#e <- 0.35
#BESS_one_arm(theta_star, c, e, 0.5, 0.5, outcome = "binary",     n_max = 150)$n

#e <- 0.4
#BESS_one_arm(theta_star, c, e, 0,   10,  outcome = "continuous", n_max = 150)$n

#e <- 0.375
#BESS_one_arm(theta_star, c, e, 1,   2,   outcome = "count",      n_max = 150)$n

# ------------------------------------------------------------------------
# Compute Pr(H = H_1|e,n) for two-arm with continuous data ---------------
# ------------------------------------------------------------------------
# Posterior probability of H1 (theta > theta_star) for the two-arm
# continuous design.
#   n          per-candidate sample size
#   theta_star clinically minimum effect size
#   e          observed effect (evidence)
#   a, b       Normal prior mean and sd
#   sigma      known sampling sd
#   q          prior probability of H1 (default 0.5)
post_prob_H1_two_cont <- function(n, theta_star, e, a, b, sigma, q = 0.5) {

  var_post <- 1 / (1 / b^2 + n / sigma^2)
  mu_post  <- var_post * (a / b^2 + n * e / sigma^2)
  z        <- (theta_star - mu_post) / sqrt(var_post)
  # Compute the upper tail directly. `1 - pnorm(z)` cancels to exactly 0
  # for z above about 8.3, and post_prob_H1() already uses this form.
  p_theta  <- pnorm(z, lower.tail = FALSE)

  # Prior tail mass above theta_star, and its complement
  C1 <- pnorm(theta_star, mean = a, sd = b, lower.tail = FALSE)
  C0 <- 1 - C1

  if (isTRUE(all.equal(q, 0.5))) {
    (C0 * p_theta) / (C1 + (C0 - C1) * p_theta)
  } else {
    denom <- (1 - q) / C0 + (q / C1 - (1 - q) / C0) * p_theta
    (q / C1 * p_theta) / denom
  }
}

# ------------------------------------------------------------------------
# BESS Algorithm 1: Find sample size for two-arm continuous trial --------
# ------------------------------------------------------------------------
# Parameters:
#   theta_star: clinically minimum effect size;
#   c: confidence level; e: evidence; a, b: hyperparameters;
#   q = 0.5: prior prob of H_1; sig = 1: standard deviation of the continuous outcome;
#   n_min, n_max: minimum and maximum candidate sample sizes; 
BESS_cont <- function(theta_star, c, e, a, b, n_min=5, n_max=100, q = 0.5, sig = 1){
  # Two-arm continuous BESS: scan seq(n_min, n_max) and report the first n
  # whose posterior probability of H1 reaches c. Returns a message string
  # when the target is already met at n_min or never met by n_max.
  candidates <- seq(n_min, n_max)
  probs <- vapply(candidates, function(n) {
    post_prob_H1_two_cont(n, theta_star = theta_star, e = e, a = a, b = b,
                          sigma = sig, q = q)
  }, numeric(1))

  # Guard clauses for an uninformative search range
  if (probs[1] >= c) {
    return("Can lower n_min.")
  }
  if (probs[length(probs)] < c) {
    return("Need to increase max sample size!")
  }

  hit <- match(TRUE, probs >= c)
  list(n = candidates[hit], prob = probs[hit], prob_h1_vec = probs)
}

# ------------------------------------------------------------------------
# Unit-Testing Two-Arm Continuous ----------------------------------------
# ------------------------------------------------------------------------
#theta_star <- 0.05
#c <- 0.8
#e <- 0.15
#BESS_cont(theta_star, c, e, 0, 10, n_max = 150)



